
pacman::p_load(tidyverse,
               fst,
               fs,
               ggthemes,
               ggrepel,
               yardstick,
               lubridate,
               cowplot
)

# Load the test-set fitted values for every model window and stack them into
# `df_all_models`.
#
# The window-specific values (TRAINING_*/TEST_* quarters, the suffixes and
# SUFFIX_FITTED_) are deliberately assigned at top level: code further down
# reads SUFFIX_FITTED_ as left by the final iteration.
#
# Each window's data is collected in a preallocated list and bound once after
# the loop, instead of re-binding (and re-copying) the accumulated data frame
# on every iteration. seq_len() makes the loop a no-op when NUM_MODELS is 0.
l_all_models <- vector("list", NUM_MODELS)

for (i in seq_len(NUM_MODELS)) {

  print(paste("Running model", i))

  TRAINING_START_QTR <- V_TRAINING_START_QTR[i]
  TRAINING_END_QTR <- V_TRAINING_END_QTR[i]
  TEST_START_QTR <- V_TEST_START_QTR[i]
  TEST_END_QTR <- V_TEST_END_QTR[i]

  # e.g. "Train<year>Q<quarter><year>Q<quarter>" (start window, then end)
  TRAINING_SUFFIX <- paste0(
    "Train",
    year(ymd(TRAINING_START_QTR)), "Q", quarter(ymd(TRAINING_START_QTR)),
    year(ymd(TRAINING_END_QTR)), "Q", quarter(ymd(TRAINING_END_QTR))
  )

  TEST_SUFFIX <- paste0(
    "Test",
    year(ymd(TEST_START_QTR)), "Q", quarter(ymd(TEST_START_QTR)),
    year(ymd(TEST_END_QTR)), "Q", quarter(ymd(TEST_END_QTR))
  )

  SUFFIX_FITTED_ <- paste0(
    TRAINING_SUFFIX, "_", TEST_SUFFIX, "_",
    RAND_NO_CID_SMALLEST_LARGEST, "_", MODEL_METRIC
  )

  df_one_model <- read_fst(paste0(
    "../../data/pipeline_outputs/", SPECIAL_SUFFIX,
    "/", "fitted_merged_", SUFFIX_FITTED_, "/",
    "test_fitted_CRA.fst"
  )) %>%
    # Keep only rows with a known income level (1-4) ...
    filter(Income_Level %in% 1:4) %>%
    # ... and collapse it to a binary indicator: 0 for levels 1-2, 1 for 3-4.
    mutate(
      d_Income_Level = case_when(
        Income_Level %in% 1:2 ~ 0,
        Income_Level %in% 3:4 ~ 1,
        TRUE ~ NA_real_
      )
    )

  l_all_models[[i]] <- df_one_model
}

# Only the binary indicator is used downstream, so the raw level is dropped.
# any_of() keeps this step valid even when no model was loaded.
df_all_models <- bind_rows(l_all_models) %>%
  select(-any_of("Income_Level"))

rm(df_one_model, l_all_models)

#' Check whether fitted output exists on disk for a model
#'
#' Looks at the directory
#' `../../data/pipeline_outputs/<SPECIAL_SUFFIX>/fitted_<model>_<SUFFIX_FITTED_>`
#' (both suffixes are read from the calling environment / globals).
#'
#' @param model A single model name, e.g. one element of `V_POTENTIAL_MODELS`.
#' @return `TRUE` if the directory exists and contains at least one entry,
#'   `FALSE` otherwise.
Determine_Data_Available <- function(model) {

  path_model <- file.path(
    "../../data/pipeline_outputs",
    SPECIAL_SUFFIX,
    paste0("fitted_", model, "_", SUFFIX_FITTED_)
  )

  # A model that was never fitted has no directory at all. Listing a missing
  # directory is an error, so test for existence first and report such a model
  # as unavailable instead of aborting the whole script.
  dir.exists(path_model) && length(list.files(path_model)) > 0

}

# Flag which candidate models have fitted output, then keep only those.
v_existing_models <- map_lgl(V_POTENTIAL_MODELS, Determine_Data_Available)
v_models <- V_POTENTIAL_MODELS[v_existing_models]

# Reverse the riskscore scale (850 - riskscore) when that model is present.
# NOTE(review): presumably so that larger values indicate higher default risk,
# matching the direction of the other models' fitted values -- confirm.
if (is.element("riskscore", v_models)) {
  df_all_models <- mutate(df_all_models, riskscore = 850 - riskscore)
}

#' Compute yearly and pooled ROC AUC per model for a subset of the test data
#'
#' Reads the globals `df_all_models` (stacked fitted values) and `v_models`
#' (names of the model columns). The selected rows are reshaped to one row per
#' observation and model, and ROC AUC is computed with `t_default` as the truth
#' column, `fitted_value` as the score and the second factor level as the event.
#'
#' @param majority_or_minority One of "both", "majority" or "minority"; keeps
#'   rows whose `d_Income_Level` is in c(0, 1), 1 or 0 respectively.
#' @param thick_or_thin One of "both", "thick" or "thin"; keeps rows whose
#'   `is_thick` is in c(0, 1), 1 or 0 respectively.
#' @return A list with two data frames: `df_yearly` (columns `Year`, `Model`,
#'   `roc_auc`) and `df_overall` (columns `Model`, `roc_auc`).
Prepare_Plotting_Data <- function(majority_or_minority, thick_or_thin) {

  if (length(majority_or_minority) != 1 ||
      !majority_or_minority %in% c("both", "majority", "minority")) {
    stop("majority_or_minority must be one of 'both', 'majority', or 'minority'")
  }

  if (length(thick_or_thin) != 1 ||
      !thick_or_thin %in% c("both", "thick", "thin")) {
    stop("thick_or_thin must be one of 'both', 'thick', or 'thin'")
  }

  cra_subset <- switch(majority_or_minority,
                       "both" = c(0, 1),
                       "majority" = 1,
                       "minority" = 0
  )

  thickness <- switch(thick_or_thin,
                      "both" = c(0, 1),
                      "thick" = 1,
                      "thin" = 0
  )

  # One row per observation and model. all_of() is required because v_models
  # is an external character vector, not a column of the data.
  df_pivoted <- df_all_models %>%
    filter(d_Income_Level %in% cra_subset,
           is_thick %in% thickness) %>%
    select(-cid, -qtr) %>%
    pivot_longer(cols = all_of(v_models),
                 names_to = "Model",
                 values_to = "fitted_value")

  # Score every group of an already grouped data frame. The score column is
  # passed unnamed: probability metrics take their columns through `...`
  # rather than an `estimate` argument.
  score_groups <- function(df_grouped) {
    df_grouped %>%
      nest() %>%
      mutate(
        roc_auc = map_dbl(
          data,
          ~ roc_auc(.x, truth = t_default, fitted_value,
                    event_level = "second") %>%
            pull(.estimate)
        )
      ) %>%
      select(-data) %>%
      ungroup()
  }

  list(
    "df_yearly" = score_groups(group_by(df_pivoted, Year, Model)),
    "df_overall" = score_groups(group_by(df_pivoted, Model))
  )

}

#' Plot yearly ROC AUC for the XGB and Logistic models
#'
#' @param df_plotting_data Name (a string) of a list holding `df_yearly` and
#'   `df_overall`; the object is looked up with `get()`.
#' @return A ggplot object: one line per model across `Year`, annotated near
#'   x = 2008 with each model's all-years ROC AUC. Rows for "riskscore" are
#'   excluded from both the lines and the annotations.
Make_Plot <- function(df_plotting_data) {

  l_auc <- get(df_plotting_data)

  # Turn the raw model names into display labels.
  relabel_models <- function(df) {
    mutate(
      df,
      Model = factor(Model,
                     levels = c("xgb", "riskscore", "logistic"),
                     labels = c("XGB", "riskscore", "Logistic"))
    )
  }

  df_models <- relabel_models(pluck(l_auc, "df_yearly"))

  # Vertical position of a model's annotation: its 2008 AUC plus an offset.
  label_height_2008 <- function(model_label, offset) {
    auc_2008 <- df_models %>%
      filter(Model == model_label, Year == 2008) %>%
      pull(roc_auc)
    sum(auc_2008, offset)
  }

  xgb_point_2008_ <- label_height_2008("XGB", .045)
  logistic_point_2008_ <- label_height_2008("Logistic", -.0075)

  df_overall <- pluck(l_auc, "df_overall") %>%
    relabel_models() %>%
    filter(Model != "riskscore") %>%
    mutate(
      x_pos = 2008,
      y_pos = case_when(Model == "XGB" ~ xgb_point_2008_,
                        Model == "Logistic" ~ logistic_point_2008_)
    )

  df_lines <- filter(df_models, Model != "riskscore")

  plot_theme <- theme(
    text = element_text(size = 20, family = "Avenir"),
    panel.grid = element_blank(),
    axis.title.x = element_text(vjust = -1),
    axis.title.y = element_text(vjust = 1),
    legend.position = "bottom",
    legend.background = element_rect(),
    legend.box.just = "center",
    legend.key.width = unit(1, "cm"),
    legend.key.height = unit(.55, "cm")
  )

  overall_labels <- geom_text(
    data = df_overall,
    aes(x = x_pos,
        y = y_pos,
        label = paste0(Model, " (All Years): ", round(roc_auc, 3))),
    show.legend = FALSE,
    size = 6
  )

  ggplot(df_lines, aes(x = Year, y = roc_auc, color = Model)) +
    scale_color_ptol() +
    geom_point(key_glyph = "path") +
    geom_line(size = 1.5, key_glyph = "path") +
    theme_minimal() +
    ylim(c(.75, .90)) +
    plot_theme +
    guides(lty = guide_legend(override.aes = list(size = 1))) +
    labs(x = "Year", y = "ROC AUC") +
    overall_labels

}


# Thick and Thin ----------------------------------------------------------

# ROC AUC tables for the full sample and for each d_Income_Level subset, with
# thick and thin files pooled. The first two computations are timed.
# NOTE(review): tic()/toc() are not provided by any package loaded at the top
# of this file (presumably tictoc) -- confirm it is attached by the caller.
tic()
df_plotting_all <- Prepare_Plotting_Data("both", "both")
toc()

tic()
df_plotting_majority <- Prepare_Plotting_Data("majority", "both")
toc()

df_plotting_minority <- Prepare_Plotting_Data("minority", "both")

# Make_Plot() takes the *name* of each list, not the list itself.
p_overall  <- Make_Plot("df_plotting_all")
p_majority <- Make_Plot("df_plotting_majority")
p_minority <- Make_Plot("df_plotting_minority")

#' Write the yearly and overall ROC AUC tables of one result list to CSV
#'
#' Files are written to
#' `../../data/pipeline_outputs/<SPECIAL_SUFFIX>/cra_plots_<SUFFIX_FITTED_>/`
#' as `<subset>_yearly_roc_auc.csv` and `<subset>_overall_roc_auc.csv`, where
#' `<subset>` is the first of "all", "majority" or "minority" found in `l_data`.
#' The output directory is created if it does not exist yet.
#'
#' @param l_data Name (a single string) of a list holding `df_yearly` and
#'   `df_overall`; the object is looked up with `get()`.
#' @return The result of the final `write_csv()` call, invisibly.
Save_ROC_AUC_Data <- function(l_data) {

  l_auc <- get(l_data)

  output_name <- str_extract(l_data, "all|majority|minority")

  # Without a recognised subset the files would be named "NA_..." and
  # different inputs would silently overwrite each other.
  if (length(output_name) != 1 || is.na(output_name)) {
    stop("l_data must contain one of 'all', 'majority', or 'minority'")
  }

  dir_out <- str_glue("../../data/pipeline_outputs/{SPECIAL_SUFFIX}/cra_plots_{SUFFIX_FITTED_}")

  # write_csv() cannot create missing directories; dir_create() does nothing
  # if the directory is already there.
  dir_create(dir_out)

  path_yearly <- str_glue("{dir_out}/{output_name}_yearly_roc_auc.csv")
  path_overall <- str_glue("{dir_out}/{output_name}_overall_roc_auc.csv")

  write_csv(pluck(l_auc, "df_yearly"), path_yearly)
  write_csv(pluck(l_auc, "df_overall"), path_overall)

}

# Write the AUC tables behind each plot.
walk(
  c("df_plotting_all", "df_plotting_majority", "df_plotting_minority"),
  Save_ROC_AUC_Data
)

# Save each plot as a 7 x 7 inch PNG in the cra_plots_ output directory.
walk2(
  list(p_overall, p_majority, p_minority),
  c("CRA_roc_auc_overall.png",
    "CRA_roc_auc_majority.png",
    "CRA_roc_auc_minority.png"),
  function(p, file_name) {
    ggsave(
      filename = paste0("../../data/pipeline_outputs/", SPECIAL_SUFFIX, "/",
                        "cra_plots_", SUFFIX_FITTED_, "/", file_name),
      plot = p,
      width = 7,
      height = 7,
      units = "in",
      dpi = 1000
    )
  }
)

# 2 x 2 panel: overall plot top-left, top-right cell left empty, the two
# subset plots on the bottom row.
p_to_save <- plot_grid(
  p_overall, NULL, p_majority, p_minority,
  labels = c("A: Overall", "", "B: Non-LMI", "C: LMI"),
  nrow = 2,
  ncol = 2
)

